{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import scipy.io\n",
    "\n",
    "import os, re\n",
    "\n",
    "import claude.utils as cu\n",
    "import claude.claudeflow.autoencoder as ae\n",
    "import claude.claudeflow.helper as cfh\n",
    "import claude.claudeflow.training as cft"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 42\n",
    "tf.set_random_seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "# Parameters\n",
    "# Channel Parameters\n",
    "chParam = cu.AttrDict()\n",
    "chParam.M = 16\n",
    "\n",
    "# Auto-Encoder Parameters\n",
    "aeParam = cu.AttrDict()\n",
    "aeParam.constellationDim = 2\n",
    "aeParam.constellationOrder = chParam.M\n",
    "aeParam.nLayers\t\t= 2\n",
    "aeParam.nHidden\t= 32\n",
    "aeParam.activation  = tf.nn.selu\n",
    "aeParam.dropout\t\t= False\n",
    "aeParam.dtype       = tf.float32\n",
    "\n",
    "# Training Parameters\n",
    "trainingParam = cu.AttrDict()\n",
    "trainingParam.sampleSize\t= 512*chParam.M # Increase for better results (especially if M>16)\n",
    "trainingParam.batchSize \t= 32*chParam.M  # Increase for better results (especially if M>16)\n",
    "trainingParam.learningRate\t= 0.001\n",
    "trainingParam.displayStep\t= 20\n",
    "trainingParam.path\t\t\t= 'results_AWGN_end2end'\n",
    "trainingParam.filename\t\t= 'M{}'.format(chParam.M)\n",
    "trainingParam.earlyStopping = 250\n",
    "trainingParam.iterations = 250\n",
    "\n",
    "# TF constants\n",
    "two = tf.constant(2,aeParam.dtype)\n",
    "DIM = tf.constant(aeParam.constellationDim, aeParam.dtype)\n",
    "PI = tf.constant(np.pi,aeParam.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0917 07:56:08.645858 139656758703936 deprecation.py:506] From /home/rasmus/.conda/envs/claudeOnline/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n"
     ]
    }
   ],
   "source": [
    "# Tx Graph     \n",
    "X = tf.placeholder( aeParam.dtype, shape=(None, chParam.M) )\n",
    "enc, enc_seed = ae.encoder(X, aeParam)\n",
    "\n",
    "# Channel Graph\n",
    "sigma2_noise = tf.constant(0.1,aeParam.dtype)\n",
    "noise = tf.sqrt( sigma2_noise )\\\n",
    "            *tf.rsqrt(two)\\\n",
    "            *tf.random_normal(shape=tf.shape(enc),dtype=aeParam.dtype)\n",
    "channel_out = enc + noise\n",
    "\n",
    "dec = ae.decoder(channel_out, aeParam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "per_ex_loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=X,logits=dec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0917 07:56:09.417371 139656758703936 deprecation.py:323] From /home/rasmus/.conda/envs/claudeOnline/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    }
   ],
   "source": [
    "# Loss\n",
    "correct_prediction = tf.equal(tf.argmax(X,1), tf.argmax(dec,1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, aeParam.dtype))\n",
    "loss = tf.reduce_mean(per_ex_loss)\n",
    "\n",
    "optimizer = tf.train.AdamOptimizer(learning_rate=trainingParam.learningRate).minimize(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0917 07:56:10.303782 139656758703936 lazy_loader.py:50] \n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "metricsDict = {'xentropy':loss, 'accuracy_metric':accuracy}\n",
    "meanMetricOpsDict, updateOps, resetOps = cft.create_mean_metrics(metricsDict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "init = tf.global_variables_initializer()\n",
    "sess = tf.Session()\n",
    "sess.run(init)\n",
    "\n",
    "saver = tf.train.Saver()\n",
    "checkpoint_path = os.path.join(trainingParam.path,'checkpoint',trainingParam.filename,'best')\n",
    "if not os.path.exists(checkpoint_path):\n",
    "    os.makedirs(checkpoint_path)\n",
    "else:\n",
    "    pass\n",
    "#     print(\"Restoring checkpoint...\", flush=True)\n",
    "#     saver.restore(sess=sess,save_path=checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoche: 20 - avgLoss: 0.6990486979484558 - avgAcc: 0.7659912109375\n",
      "epoche: 40 - avgLoss: 0.5796283483505249 - avgAcc: 0.782958984375\n",
      "epoche: 60 - avgLoss: 0.5354400873184204 - avgAcc: 0.7965087890625\n",
      "epoche: 80 - avgLoss: 0.5532417297363281 - avgAcc: 0.79736328125\n",
      "epoche: 100 - avgLoss: 0.5578113198280334 - avgAcc: 0.797119140625\n",
      "epoche: 120 - avgLoss: 0.5328306555747986 - avgAcc: 0.800537109375\n",
      "epoche: 140 - avgLoss: 0.5588166117668152 - avgAcc: 0.7838134765625\n",
      "epoche: 160 - avgLoss: 0.5384258031845093 - avgAcc: 0.7989501953125\n",
      "epoche: 180 - avgLoss: 0.5365269780158997 - avgAcc: 0.8017578125\n",
      "epoche: 200 - avgLoss: 0.535810112953186 - avgAcc: 0.7987060546875\n",
      "epoche: 220 - avgLoss: 0.5383657217025757 - avgAcc: 0.8013916015625\n",
      "epoche: 240 - avgLoss: 0.5349321365356445 - avgAcc: 0.80078125\n"
     ]
    }
   ],
   "source": [
    "nBatches = int(trainingParam.sampleSize/trainingParam.batchSize)\n",
    "bestLoss = 10000\n",
    "\n",
    "for epoche in range(1, trainingParam.iterations+1):\n",
    "    sess.run(resetOps)\n",
    "    for batch in range(0,nBatches):\n",
    "        data, _, _ = cu.hotOnes(trainingParam.batchSize,(1,0),chParam.M)\n",
    "        feedDict = {X: data}\n",
    "        sess.run([optimizer, updateOps], feed_dict=feedDict)\n",
    "\n",
    "    [outAvgLoss, outAvgAccuracy] = sess.run([meanMetricOpsDict['xentropy'], meanMetricOpsDict['accuracy_metric']], feed_dict=feedDict)\n",
    "\n",
    "    if outAvgLoss < bestLoss:\n",
    "        bestLoss = outAvgLoss\n",
    "        lastImprovement = epoche\n",
    "        saver.save(sess=sess,save_path=checkpoint_path)\n",
    "\n",
    "    if epoche - lastImprovement > trainingParam.earlyStopping:\n",
    "        print(\"Breaking due to no improvement\")\n",
    "        break;\n",
    "\n",
    "    if epoche%trainingParam.displayStep == 0:\n",
    "        print('epoche: {} - avgLoss: {} - avgAcc: {}'.format(epoche,outAvgLoss,outAvgAccuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0917 07:56:25.151086 139656758703936 deprecation.py:323] From /home/rasmus/.conda/envs/claudeOnline/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use standard file APIs to check for files with this prefix.\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQ4AAAD8CAYAAACGnEoDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADmZJREFUeJzt3W+oZPV9x/H3xzWaJyVRdxONuq5LlzS2lCYOZptACK2lKsVNmgRMC9GiLNJKoc8WhLb4pLFPStvYpptUakpQ20CaTWqQqJE86abeDaZGxWSzsPV2Rdc/WEKLutlvH9wx3K5z997fztyZc2beL7jcmTm/Ob8vZ+989vzOnHN+qSokqcVZsy5AUv8YHJKaGRySmhkckpoZHJKaGRySmhkckpoZHJKaGRySmp096wLWsnXr1tqxY8esy5Dm2qFDh16sqm2t7+tscOzYsYOlpaVZlyHNtSRHz+R9DlUkNTM4JDUzOCQ1MzgkNTM4JDUzOCQ1Mzg0M4eOvsJd3z7MoaOvzLoUNerseRyab4eOvsLvfvEgr584yTlnn8WXb9nNlZedN+uytEHucWgmDh55iddPnORkwRsnTnLwyEuzLkkNDA7NxO6dF3DO2WexJfC2s89i984LZl2SGjhU0Uxcedl5fPmW3Rw88hK7d17gMKVnDA7NzJWXnWdg9JRDFUnNDA5JzQwOSc0MDknNDA5JzQwOSc0MDknNDI4O8aIv9YUngHWEF32pT9zj6Agv+lKfGBwd4UVf6hOHKh3hRV/qk4kER5K7gd8CXqiqXxqxPMBfAtcB/wPcVFXfm0Tf88SLvtQXkxqq/ANwzWmWXwvsGv7sBf52Qv1KmoGJBEdVfQd4+TRN9gBfqhUHgXcmuWgSfUuavmkdHL0YeHbV8+Xha/9Pkr1JlpIsHT9+fEqlSWo1reDIiNfqLS9U7a+qQVUNtm1rnkBb0pRMKziWgUtXPb8EODalviVN2LSC4wDwmazYDbxaVc9NqW9JEzapr2PvBT4KbE2yDPwJ8DaAqvo88AArX8UeZuXr2N+bRL+SZmMiwVFVn15neQF/MIm+JM2ep5xLamZwSGpmcEhqZnBIamZwSGpmcEhqZnBIamZwSGpmcEhqZnBIamZwSGpmcEhqZnBIamZwSGpmcEhqZnBIZ2iRJwl3JjfpDCz6JOHucUhnYNEnCTc4pDOw6JOE93qocujoK07SrJlY9EnCexsciz7G1Owt8iThvR2qLPoYU5ql3gbHoo8xpVnq7VBl0ceY0iz1NjhgsceY0iz1dqgiaXYMDknNDA5JzQwOSc0MDk3MIl8tumh6/a2KusMzeRfLRPY4klyT5Jkkh5PsG7H8piTHkzw+/LllEv2qOzyTd7GMvceRZAtwF/AbwDLwWJIDVfXUKU3vr6rbxu1P3fTmmbxvnDjpmbwLYBJDlauAw1V1BCDJfcAe4NTg0BzzTN7FMonguBh4dtXzZeCDI9p9IslHgB8Cf1RVz45oox7zTN7FMYljHBnxWp3y/OvAjqr6ZeAh4J6RK0r2JllKsnT8+PEJlCZpM0wiOJaBS1c9vwQ4trpBVb1UVa8Nn34BuHLUiqpqf1UNqmqwbdu2CZQmaTNMIjgeA3YluTzJOcANwIHVDZJctOrp9cDTE+hX0oyMfYyjqk4kuQ14ENgC3F1VTya5A1iqqgPAHya5HjgBvAzcNG6/kmYnVacejuiGwWBQS0tLsy5DmmtJDlXVoPV9nnIuqZnBIamZwSGpmcEhqZnBIamZwSGpmcHRId4IR33hjXw6whvhqE/c4+gIb4SjPjE4OsIpLdUnDlU6whvhqE8Mjg7xRjjqC4cqkpoZHJKaGRySmhkckpoZHJKaGRySmhkckpoZHJKaGRySmhkckpoZHJKaGRySmhkckpoZHJKaGRySmhkckpoZHJKaGRySmhkckppNJDiSXJPkmSSHk+wbsfzcJPcPl383yY5J9CtpNsYOjiRbgLuAa4ErgE8nueKUZjcDr1TVzwN/Adw5br/SounSTH+TuMv5VcDhqjoCkOQ+YA/w1Ko2e4A/HT7+CvC5JKmqmkD/WnCHjr4y99NKdG2mv0kEx8XAs6ueLwMfXKtNVZ1I8ipwAfDiBPrXhPXpg9i1D9RmGTXTX9+DIyNeO3VPYiNtSLIX2Auwffv28StTs759ELv2gdosb87098aJk52Y6W8SwbEMXLrq+SXAsTXaLCc5G3gH8PKpK6qq/cB+gMFg4DBmBvr2QezaB2qzdG2mv0kEx2PAriSXA/8F3AD8ziltDgA3Av8GfBJ4xOMb3dS3D2LXPlCbqUsz/Y0dHMNjFrcBDwJbgLur6skkdwBLVXUA+HvgH5McZmVP44Zx+4V+jcX7oo8fxC59oBZFuvof/2AwqKWlpTWX920sLnVRkkNVNWh9X2/PHB01Fpc0Hb0NjjfH4ltCL8bi0jyZxMHRmejjWFyaF70NDvCgmDQrvR2qSJodg0NSM4NDUjODQ1Izg0NSM4NDUjODQ1Izg0NSM4NDUjODQ1Izg0NSM4NDUjODQ1Izg0NSM4NDUjODQ1Izg0NSM4NDUjODo4O6NCu5NEqv7zk6j5wvRn3gHkfHOF+M+sDg6Bjni1EfOFTpGOeLUR8YHB3kfDHqOocqkpoZHJKaGRySmhkckpqNFRxJzk/yrSQ/Gv4eeUQvyU+TPD78OTBOn5Jmb9w9jn3Aw1W1C3h4+HyU/62qXxn+XD9mn5JmbNzg2APcM3x8D/CxMdcnqQfGDY53V9VzAMPf71qj3duTLCU5mGTNcEmyd9hu6fjx42OWpi7wgr35tO4JYEkeAi4csej2hn62V9WxJDuBR5I8UVU/PrVRVe0H9gMMBoNqWL86yAv25te6wVFVV6+1LMnzSS6qqueSXAS8sMY6jg1/H0nyKPB+4C3Bofky6oI9g2M+jDtUOQDcOHx8I/C1UxskOS/JucPHW4EPA0+N2a96wAv25te416p8FvinJDcD/wl8CiDJALi1qm4B3gf8XZKTrATVZ6vK4FgAXrA3v1LVzUMJg8GglpaWZl2GNNeSHKqqQev7PHNUUjODQ1Izg0NSM4NDUjODQxrTIp4d660DpTEs6tmx7nFIY1jU6SwMDmkMi3p2rEMVaQyLenaswSGNaRGns3CoIqmZwSGpmcEhqZnBIamZwSGpmcEhqZnBIanZ3AbHIl54JE3LXJ4AtqgXHknTMpd7HIt64ZE0LXMZHIt64ZE0LXM5VFnUC4+kaZnL4IDFvPBImpa5HKpI2lwGh6RmBoekZgaHpGYGh6RmBoekZgaHpGYGh6RmYwVHkk8leTLJySSD07S7JskzSQ4n2TdOn5Jmb9w9jh8Avw18Z60GSbYAdwHXAlcAn05yxZj9SpqhsU45r6qnAZKcrtlVwOGqOjJsex+wB3hqnL4lzc40jnFcDDy76vny8DVJPbXuHkeSh4ALRyy6vaq+toE+Ru2O1Bp97QX2Amzfvn0Dq5Y0C+sGR1VdPWYfy8Clq55fAhxbo6/9wH6AwWAwMlwkzd40hiqPAbuSXJ7kHOAG4MAU+p1b3k9VszbWwdEkHwf+GtgG/GuSx6vqN5O8B/hiVV1XVSeS3AY8CGwB7q6qJ8eufEF5P1V1wbjfqnwV+OqI148B1616/gDwwDh9acWo+6kaHJo2zxztGe+nqi6Y21sHzivvp6ouMDh6yPupatYcqkhqZnBIamZwSGpmcEhqZnBIamZwSGpmcEhqZnCos7yYr7s8AUyd5MV83eYehzpp1MV86g6DQ53kxXzd5lBFneTFfN1mcKizvJivuxyqSGpmcEhqZnBIamZwSGpmcEhqZnBIapaqbk6YluQ4cHSMVWwFXpxQOZuh6/VB92u0vvG9t6p+rvVNnT2Po6q2jfP+JEtVNZhUPZPW9fqg+zVa3/iSLJ3J+xyqSGpmcEhqNs/BsX/WBayj6/VB92u0vvGdUY2dPTgqqbvmeY9D0iaZm+BI8qkkTyY5mWTNI9lJrknyTJLDSfZNsb7zk3wryY+Gv0de9pnkp0keH/4cmEJdp90eSc5Ncv9w+XeT7Njsms6gxpuSHF+13W6ZYm13J3khyQ/WWJ4kfzWs/T+SfGBatTXU+NEkr67afn+87kqrai5+gPcB7wUeBQZrtNkC/BjYCZwDfB+4Ykr1/Tmwb/h4H3DnGu1+MsVttu72AH4f+Pzw8Q3A/VP+d91IjTcBn5vR391HgA8AP1hj+XXAN4EAu4HvdrDGjwLfaFnn3OxxVNXTVfXMOs2uAg5X1ZGqeh24D9iz+dXBsJ97ho/vAT42pX5PZyPbY3XdXwF+PUk6VuPMVNV3gJdP02QP8KVacRB4Z5KLplPdig3U2GxugmODLgaeXfV8efjaNLy7qp4DGP5+1xrt3p5kKcnBJJsdLhvZHj9rU1UngFeBad7Hb6P/Zp8YDgW+kuTS6ZS2IbP8m2vxq0m+n+SbSX5xvcadPXN0lCQPAReOWHR7VX1tI6sY8drEvlY6XX0Nq9leVceS7AQeSfJEVf14MhW+xUa2x6Zusw3YSP9fB+6tqteS3MrKHtKvbXplGzPr7bcR3wMuq6qfJLkO+Bdg1+ne0KvgqKqrx1zFMrD6f6NLgGNjrvNnTldfkueTXFRVzw13VV9YYx3Hhr+PJHkUeD8rY/zNsJHt8Wab5SRnA+9gwru961i3xqpafQv0LwB3TqGujdrUv7lJqKr/XvX4gSR/k2RrVa15nc2iDVUeA3YluTzJOawc7Nv0by6GDgA3Dh/fCLxlDynJeUnOHT7eCnwYeGoTa9rI9lhd9yeBR2p4RG1K1q3xlGMG1wNPT7G+9RwAPjP8dmU38OqbQ9auSHLhm8etklzFSi6cfj6KWRyJ3qQjxx9nJd1fA54HHhy+/h7ggVXtrgN+yMr/4rdPsb4LgIeBHw1/nz98fQB8cfj4Q8ATrHxz8ARw8xTqesv2AO4Arh8+fjvwz8Bh4N+BnTP4t12vxj8Dnhxut28DvzDF2u4FngPeGP793QzcCtw6XB7grmHtT7DGN34zrvG2VdvvIPCh9dbpmaOSmi3aUEXSBBgckpoZHJKaGRySmhkckpoZHJKaGRySmhkckpr9H5Q7ShVTwkJLAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "saver.restore(sess=sess,save_path=checkpoint_path)\n",
    "data, idx, xseed = cu.hotOnes(trainingParam.batchSize,(1,0),chParam.M)\n",
    "feedDict = {X: xseed}\n",
    "pred_const = sess.run(enc_seed, feed_dict=feedDict)\n",
    "plt.plot(pred_const[:,0],pred_const[:,1],'.')\n",
    "plt.axis('square');"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
